import os

import cdtools
from scipy.io import loadmat, savemat
import numpy as np
import torch as t

from helper_functions import center_probe, center_probe_fourier, split_dataset

# Set to False to run without any live (during optimization) or final plots
view = True
if view:
    import matplotlib.pyplot as plt

# The device to run the reconstruction on. Prefer the GPU, but fall back to
# the CPU so the script still runs on machines without CUDA support.
device = 'cuda' if t.cuda.is_available() else 'cpu'

# Read the diffraction data from the CXI file.
# NOTE(review): cut_zeros=False presumably keeps negative pixel values so that
# the offset applied below can preserve them — confirm against cdtools docs.
data_filename = 'data/single_shot_ptycho.cxi'
dataset = cdtools.datasets.Ptycho2DDataset.from_cxi(
    data_filename,
    cut_zeros=False,
)

# Raise every pattern by a constant 10 ADU, clipping anything still negative
# to zero. Without this offset, the low-intensity flat-field background would
# go negative in places and the algorithm could not recover it.
dataset.patterns = (dataset.patterns + 10).clamp(min=0)

# Keep a handle on the full dataset, and also split it into two disjoint
# halves whose independent reconstructions can be compared with an FRC
dataset_full = dataset
dataset_1, dataset_2 = split_dataset(dataset)

def _run_stage(model, dataset, n_epochs, lr, batch_size):
    """Run one stage of Adam optimization with progress reporting.

    Parameters
    ----------
    model : cdtools model
        The model being reconstructed (a FancyPtycho in this script).
    dataset : cdtools dataset
        The data the model is fit to.
    n_epochs : int
        Number of iterations passed to ``model.Adam_optimize``.
    lr : float
        Adam learning rate.
    batch_size : int
        Minibatch size passed to ``model.Adam_optimize``.

    ``model.report()`` is printed every time the optimizer yields. When the
    module-level ``view`` flag is set, the model is also re-plotted whenever
    ``model.epoch`` is a multiple of 5.
    """
    for _ in model.Adam_optimize(n_epochs, dataset, lr=lr,
                                 batch_size=batch_size, schedule=False):
        print(model.report())
        if view and model.epoch % 5 == 0:
            model.inspect(dataset)


# savemat does not create missing directories. Make sure the output directory
# exists before spending time on the reconstructions, rather than failing at
# the first save.
os.makedirs('results', exist_ok=True)

# We run the same reconstruction on each half and the full dataset
for plan, dataset in [('half_1', dataset_1),
                      ('half_2', dataset_2),
                      ('full', dataset_full)]:

    savefile = f'results/single_shot_ptycho_{plan}_rough.mat'
    print('To save in', savefile)

    # First, we create a ptychography model from the dataset
    model = cdtools.models.FancyPtycho.from_dataset(
        dataset,
        n_modes=3,
        propagation_distance=275e-6,
        translation_scale=5,
        probe_support_radius=350,
        simulate_probe_translation=False,
        obj_view_crop=-200,
        units='um',
    )

    # Move the model and dataset to the relevant device
    model.to(device=device)
    dataset.get_as(device=device)

    if view:
        model.inspect(dataset)

    # We initialize the background estimate to the offset level of 10 ADU.
    # This offset was added to the patterns earlier to deal with a slightly
    # nonuniform background that dips below zero in some detector regions.
    # The background is then estimated as part of the reconstruction.
    model.background.data[:] = 10

    # Stage 1: an aggressive reconstruction to recover the probe, with
    # position correction and shot-to-shot intensity correction disabled
    model.translation_offsets.requires_grad = False
    model.weights.requires_grad = False
    _run_stage(model, dataset, 50, lr=0.05, batch_size=5)

    # Next we center this probe, which usually appears off-center, and
    # re-apply the probe support mask to the centered probe
    model.probe.data = center_probe(model.probe.detach().cpu()).to(
        device=device)
    model.probe.data *= model.probe_support
    # The object is reset to ones, since it was reconstructed against the
    # probe's old (uncentered) position
    model.obj.data[:] = 1
    # The support mask is then set to all ones.
    # NOTE(review): presumably this lifts the support constraint for the
    # remaining stages — confirm this is intended.
    model.probe_support[:] = 1

    # Stage 2: hold the probe fixed while the shot-to-shot intensity
    # correction and the position correction are switched on
    model.probe.requires_grad = False
    model.weights.requires_grad = True
    model.translation_offsets.requires_grad = True
    _run_stage(model, dataset, 25, lr=0.02, batch_size=25)

    # Stage 3: release the probe so it can converge along with everything else
    model.probe.requires_grad = True
    _run_stage(model, dataset, 50, lr=0.02, batch_size=5)

    # This orthogonalizes the probe reconstruction
    model.tidy_probes()

    # We save out the results
    savemat(savefile, model.save_results(dataset))

    # Finally, we plot the results
    if view:
        model.inspect(dataset)
        model.compare(dataset)

# Once every reconstruction has finished, block on the open figure windows
if view:
    plt.show()
